/***************************************************************************************************
 * Copyright (c) 2017-2020, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice,
 *       this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its
 *       contributors may be used to endorse or promote products derived from
 *       this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
 * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Template for a multistage threadblock-scoped planar complex GEMM
           kernel.
*/

#pragma once

#include "cutlass/cutlass.h"

#include "cutlass/aligned_buffer.h"
#include "cutlass/arch/memory.h"
#include "cutlass/array.h"
#include "cutlass/array_planar_complex.h"
#include "cutlass/functional.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/numeric_types.h"

#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/threadblock/mma_planar_complex_base.h"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace threadblock {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Structure to compute a planar complex matrix product using a multistage
/// pipeline of cp.async (LDGSTS) copies feeding warp-level MMA operations.
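///
/// Each operand is stored as two planes rather than interleaved re/im pairs:
/// the imaginary plane lives at a fixed element offset
/// (Base::SharedStorage::kImaginaryStrideA/B) past the real plane in shared
/// memory, and accumulators are ArrayPlanarComplex fragments with separate
/// .real and .imag parts.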
template <
        /// Size of the Gemm problem - concept: gemm::GemmShape<>
        typename Shape_,
        /// Iterates over tiles of A operand in global memory
        /// (concept: ReadableTileIterator | ForwardTileIterator |
        ///  MaskedTileIterator)
        typename IteratorA_,
        /// Iterates over tiles of A operand in shared memory
        /// (concept: WriteableTileIterator | RandomAccessTileIterator)
        typename SmemIteratorA_,
        /// Cache operation for operand A
        cutlass::arch::CacheOperation::Kind CacheOpA,
        /// Iterates over tiles of B operand in global memory
        /// (concept: ReadableTileIterator | ForwardTileIterator |
        ///  MaskedTileIterator)
        typename IteratorB_,
        /// Iterates over tiles of B operand in shared memory
        /// (concept: WriteableTileIterator | RandomAccessTileIterator)
        typename SmemIteratorB_,
        /// Cache operation for operand B
        cutlass::arch::CacheOperation::Kind CacheOpB,
        /// Data type of accumulator matrix
        typename ElementC_,
        /// Layout of accumulator matrix
        typename LayoutC_,
        /// Policy describing tuning details (concept: MmaPolicy)
        typename Policy_,
        /// Number of stages
        int Stages,
        /// Transformation applied to A
        ComplexTransform TransformA = ComplexTransform::kNone,
        /// Transformation applied to B
        ComplexTransform TransformB = ComplexTransform::kNone>
class MmaPlanarComplexMultistage
        : public MmaPlanarComplexBase<Shape_, Policy_, Stages> {
public:
    ///< Base class
    using Base = MmaPlanarComplexBase<Shape_, Policy_, Stages>;

    ///< Size of the Gemm problem - concept: gemm::GemmShape<>
    using Shape = Shape_;

    ///< Iterates over tiles of A operand in global memory
    using IteratorA = IteratorA_;

    ///< Iterates over tiles of B operand in global memory
    using IteratorB = IteratorB_;

    ///< Data type of accumulator matrix
    using ElementC = ElementC_;

    ///< Layout of accumulator matrix
    using LayoutC = LayoutC_;

    ///< Policy describing tuning details
    using Policy = Policy_;

    ///< Architecture tag
    using ArchTag = arch::Sm80;

    using SmemIteratorA = SmemIteratorA_;
    using SmemIteratorB = SmemIteratorB_;

    static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
    static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;

    /// Transformation applied to A
    static ComplexTransform const kTransformA = TransformA;

    /// Transformation applied to B
    static ComplexTransform const kTransformB = TransformB;

    //
    // Dependent types
    //

    /// Fragment of accumulator tile
    using FragmentC =
            ArrayPlanarComplex<typename Policy::Operator::FragmentC::Element,
                               Policy::Operator::FragmentC::kElements>;

    /// Warp-level Mma
    using Operator = typename Policy::Operator;

    /// Internal structure exposed for introspection.
    struct Detail {
        static_assert(
                Base::kWarpGemmIterations > 1,
                "The pipelined structure requires at least two warp-level "
                "GEMM operations.");

        /// Number of LDGSTS instructions to load one stage of operand A
        static int const TBLDGSTSIterationsA =
                IteratorA::ThreadMap::Iterations::kCount;

        /// Number of LDGSTS instructions to load one stage of operand B
        static int const TBLDGSTSIterationsB =
                IteratorB::ThreadMap::Iterations::kCount;

        /// Number of stages
        static int const kStages = Stages;

        /// Number of LDGSTS instructions to load one group of operand A
        static int const kAccessesPerGroupA =
                (TBLDGSTSIterationsA + Base::kWarpGemmIterations - 1) /
                Base::kWarpGemmIterations;

        /// Number of LDGSTS instructions to load one group of operand B
        static int const kAccessesPerGroupB =
                (TBLDGSTSIterationsB + Base::kWarpGemmIterations - 1) /
                Base::kWarpGemmIterations;
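
        // Illustrative arithmetic (hypothetical values): with
        // TBLDGSTSIterationsA == 8 and kWarpGemmIterations == 4,
        // kAccessesPerGroupA == (8 + 4 - 1) / 4 == 2, i.e. each warp-level
        // MMA iteration hides two LDGSTS instructions for operand A.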
    };

private:
    using WarpFragmentA = typename Operator::FragmentA;
    using WarpFragmentB = typename Operator::FragmentB;

private:
    //
    // Data members
    //

    /// Iterator to write threadblock-scoped tile of A operand to shared memory
    SmemIteratorA smem_iterator_A_;

    /// Iterator to write threadblock-scoped tile of B operand to shared memory
    SmemIteratorB smem_iterator_B_;

public:
    /// Construct from tensor references
    CUTLASS_DEVICE
    MmaPlanarComplexMultistage(
            ///< Shared storage needed for internal use by threadblock-scoped
            ///< GEMM
            typename Base::SharedStorage& shared_storage,
            ///< ID within the threadblock
            int thread_idx,
            ///< ID of warp
            int warp_idx,
            ///< ID of each thread within a warp
            int lane_idx)
            : Base(shared_storage, thread_idx, warp_idx, lane_idx),
              smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
              smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) {
        // Compute warp location within threadblock tile by mapping the warp_id
        // to three coordinates:
        //   _m: the warp's position within the threadblock along the M dimension
        //   _n: the warp's position within the threadblock along the N dimension
        //   _k: the warp's position within the threadblock along the K dimension

        int warp_idx_mn =
                warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
        int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);

        int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
        int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
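
        // Worked example (hypothetical counts): with WarpCount::kM == 2 and
        // WarpCount::kN == 2, warp_idx == 5 maps to warp_idx_mn == 1,
        // warp_idx_k == 1, warp_idx_m == 1, and warp_idx_n == 0.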

        // Add per-warp offsets in units of warp-level tiles
        this->warp_tile_iterator_A_.add_tile_offset(
                {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
        this->warp_tile_iterator_B_.add_tile_offset(
                {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
    }

private:
    CUTLASS_DEVICE
    void copy_tiles_and_advance(IteratorA& iterator_A_real,
                                IteratorA& iterator_A_imag,

                                IteratorB& iterator_B_real,
                                IteratorB& iterator_B_imag,

                                int group_start_A = 0, int group_start_B = 0) {
        iterator_A_real.set_iteration_index(group_start_A *
                                            IteratorA::kAccessesPerVector);
        iterator_A_imag.set_iteration_index(group_start_A *
                                            IteratorA::kAccessesPerVector);
        this->smem_iterator_A_.set_iteration_index(group_start_A);

        // LDGSTS for operand A
        CUTLASS_PRAGMA_UNROLL
        for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) {
            typename IteratorA::AccessType* dst_ptr =
                    reinterpret_cast<typename IteratorA::AccessType*>(
                            this->smem_iterator_A_.get());

            int const kSrcBytes =
                    sizeof_bits<typename IteratorA::Element>::value *
                    IteratorA::ThreadMap::kElementsPerAccess /
                    IteratorA::kAccessesPerVector / 8;
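
            // For example (hypothetical configuration): 16-bit elements with
            // kElementsPerAccess == 8 and kAccessesPerVector == 1 give
            // kSrcBytes == 16 * 8 / 1 / 8 == 16, the widest transfer
            // cp.async supports.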

            CUTLASS_PRAGMA_UNROLL
            for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
                auto gmem_ptr_real = iterator_A_real.get();
                auto gmem_ptr_imag = iterator_A_imag.get();

                bool pred_guard = iterator_A_real.valid();
                cutlass::arch::cp_async<kSrcBytes, kCacheOpA>(
                        dst_ptr + v, gmem_ptr_real, pred_guard);
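
                // The imaginary plane sits kImaginaryStrideA elements past
                // the real plane in shared memory; dividing by
                // kElementsPerAccess converts that element offset into
                // AccessType strides for the destination pointer.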
                cutlass::arch::cp_async<kSrcBytes, kCacheOpA>(
                        dst_ptr + v +
                                (Base::SharedStorage::kImaginaryStrideA /
                                 IteratorA::ThreadMap::kElementsPerAccess),
                        reinterpret_cast<char const*>(gmem_ptr_imag),
                        pred_guard);

                ++iterator_A_real;
                ++iterator_A_imag;
            }

            ++this->smem_iterator_A_;
        }

        iterator_B_real.set_iteration_index(group_start_B *
                                            IteratorB::kAccessesPerVector);
        iterator_B_imag.set_iteration_index(group_start_B *
                                            IteratorB::kAccessesPerVector);
        this->smem_iterator_B_.set_iteration_index(group_start_B);

        // LDGSTS for operand B
        CUTLASS_PRAGMA_UNROLL
        for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) {
            typename IteratorB::AccessType* dst_ptr =
                    reinterpret_cast<typename IteratorB::AccessType*>(
                            this->smem_iterator_B_.get());

            int const kSrcBytes =
                    sizeof_bits<typename IteratorB::Element>::value *
                    IteratorB::ThreadMap::kElementsPerAccess /
                    IteratorB::kAccessesPerVector / 8;

            CUTLASS_PRAGMA_UNROLL
            for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
                auto gmem_ptr_real = iterator_B_real.get();
                auto gmem_ptr_imag = iterator_B_imag.get();

                bool pred_guard = iterator_B_real.valid();
                cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(
                        dst_ptr + v, gmem_ptr_real, pred_guard);
                cutlass::arch::cp_async<kSrcBytes, kCacheOpB>(
                        dst_ptr + v +
                                (Base::SharedStorage::kImaginaryStrideB /
                                 IteratorB::ThreadMap::kElementsPerAccess),
                        reinterpret_cast<char const*>(gmem_ptr_imag),
                        pred_guard);

                ++iterator_B_real;
                ++iterator_B_imag;
            }
            ++this->smem_iterator_B_;
        }
    }
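
    /// Accumulates one warp-level planar complex product. With s_A, s_B in
    /// {+1, -1} encoding the kNone / kConjugate transforms, the product is
    ///
    ///   (a_r + i s_A a_i) * (b_r + i s_B b_i)
    ///     = (a_r b_r - s_A s_B a_i b_i) + i (s_B a_r b_i + s_A a_i b_r)
    ///
    /// and the four warp_mma calls below realize it term by term; the XOR on
    /// the transforms selects the sign of the imag_A * imag_B contribution
    /// to the real part.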

    CUTLASS_DEVICE
    void warp_mma_planar_complex(Operator& warp_mma, FragmentC& accum,
                                 WarpFragmentA const& real_A,
                                 WarpFragmentA const& imag_A,
                                 WarpFragmentB const& real_B,
                                 WarpFragmentB const& imag_B) {
        cutlass::negate<Array<typename WarpFragmentB::Element,
                              WarpFragmentB::kElements>>
                neg_op_B;

        WarpFragmentB neg_real_B = neg_op_B(real_B);
        WarpFragmentB neg_imag_B = neg_op_B(imag_B);

        warp_mma(accum.real, real_A, real_B, accum.real);

        if (kTransformB == ComplexTransform::kNone) {
            warp_mma(accum.imag, real_A, imag_B, accum.imag);
        } else {
            warp_mma(accum.imag, real_A, neg_imag_B, accum.imag);
        }

        if (kTransformA == ComplexTransform::kNone) {
            warp_mma(accum.imag, imag_A, real_B, accum.imag);
        } else {
            warp_mma(accum.imag, imag_A, neg_real_B, accum.imag);
        }

        if ((kTransformA == ComplexTransform::kNone) ^
            (kTransformB == ComplexTransform::kNone)) {
            warp_mma(accum.real, imag_A, imag_B, accum.real);
        } else {
            warp_mma(accum.real, imag_A, neg_imag_B, accum.real);
        }
    }

public:
    /// Perform a threadblock-scoped matrix multiply-accumulate
    CUTLASS_DEVICE
    void operator()(
            ///< problem size of GEMM
            int gemm_k_iterations,
            ///< destination accumulator tile
            FragmentC& accum,
            ///< iterator over A operand in global memory
            IteratorA iterator_A_real,
            ///< iterator over A operand in global memory
            IteratorA iterator_A_imag,
            ///< iterator over B operand in global memory
            IteratorB iterator_B_real,
            ///< iterator over B operand in global memory
            IteratorB iterator_B_imag,
            ///< initial value of accumulator
            FragmentC const& src_accum) {
        //
        // Prologue
        //

        // Issue several complete stages
        CUTLASS_PRAGMA_UNROLL
        for (int stage = 0; stage < Base::kStages - 1;
             ++stage, --gemm_k_iterations) {
            if (gemm_k_iterations == 0) {
                iterator_A_real.clear_mask();
                iterator_A_imag.clear_mask();
                iterator_B_real.clear_mask();
                iterator_B_imag.clear_mask();
            }

            iterator_A_real.set_iteration_index(0);
            iterator_A_imag.set_iteration_index(0);

            this->smem_iterator_A_.set_iteration_index(0);

            // LDGSTS for operand A
            CUTLASS_PRAGMA_UNROLL
            for (int j = 0; j < Detail::TBLDGSTSIterationsA; ++j) {
                typename IteratorA::AccessType* dst_ptr =
                        reinterpret_cast<typename IteratorA::AccessType*>(
                                this->smem_iterator_A_.get());

                CUTLASS_PRAGMA_UNROLL
                for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) {
                    int const kSrcBytes =
                            sizeof_bits<typename IteratorA::Element>::value *
                            IteratorA::ThreadMap::kElementsPerAccess /
                            IteratorA::kAccessesPerVector / 8;

                    bool pred_guard = iterator_A_real.valid();

                    auto src_ptr_real = iterator_A_real.get();
                    auto src_ptr_imag = iterator_A_imag.get();

                    cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
                            dst_ptr + v, src_ptr_real, pred_guard);

                    cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpA>(
                            dst_ptr + v +
                                    Base::SharedStorage::kImaginaryStrideA /
                                            IteratorA::ThreadMap::
                                                    kElementsPerAccess,
                            reinterpret_cast<char const*>(src_ptr_imag),
                            pred_guard);

                    ++iterator_A_real;
                    ++iterator_A_imag;
                }

                ++this->smem_iterator_A_;
            }

            iterator_B_real.set_iteration_index(0);
            iterator_B_imag.set_iteration_index(0);

            this->smem_iterator_B_.set_iteration_index(0);

            // LDGSTS for operand B
            CUTLASS_PRAGMA_UNROLL
            for (int j = 0; j < Detail::TBLDGSTSIterationsB; ++j) {
                typename IteratorB::AccessType* dst_ptr =
                        reinterpret_cast<typename IteratorB::AccessType*>(
                                this->smem_iterator_B_.get());

                CUTLASS_PRAGMA_UNROLL
                for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) {
                    int const kSrcBytes =
                            sizeof_bits<typename IteratorB::Element>::value *
                            IteratorB::ThreadMap::kElementsPerAccess /
                            IteratorB::kAccessesPerVector / 8;

                    bool pred_guard = iterator_B_real.valid();

                    auto src_ptr_real = iterator_B_real.get();
                    auto src_ptr_imag = iterator_B_imag.get();

                    cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
                            dst_ptr + v, src_ptr_real, pred_guard);

                    cutlass::arch::cp_async_zfill<kSrcBytes, kCacheOpB>(
                            dst_ptr + v +
                                    Base::SharedStorage::kImaginaryStrideB /
                                            IteratorB::ThreadMap::
                                                    kElementsPerAccess,
                            reinterpret_cast<char const*>(src_ptr_imag),
                            pred_guard);

                    ++iterator_B_real;
                    ++iterator_B_imag;
                }

                ++this->smem_iterator_B_;
            }

            // Move to the next stage
            iterator_A_real.add_tile_offset({0, 1});
            iterator_A_imag.add_tile_offset({0, 1});

            iterator_B_real.add_tile_offset({1, 0});
            iterator_B_imag.add_tile_offset({1, 0});

            this->smem_iterator_A_.add_tile_offset({0, 1});
            this->smem_iterator_B_.add_tile_offset({1, 0});

            // Inserts a memory fence between stages of cp.async instructions
            cutlass::arch::cp_async_fence();
        }

        // Perform accumulation in the 'd' output operand
        accum = src_accum;

        // Blocks until all but kStages-2 cp.async stages have committed.
        cutlass::arch::cp_async_wait<Base::kStages - 2>();
        __syncthreads();
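
        // Pipeline bookkeeping, by example (hypothetical kStages == 4): the
        // prologue above committed 3 full stages, and cp_async_wait<2> blocks
        // until at most 2 of them are still in flight, so the oldest stage is
        // resident in shared memory before the first warp-level loads below.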

        // Pair of fragments used to overlap shared memory loads and math
        // instructions

        WarpFragmentA warp_frag_real_A[2];
        WarpFragmentA warp_frag_imag_A[2];

        WarpFragmentB warp_frag_real_B[2];
        WarpFragmentB warp_frag_imag_B[2];

        this->warp_tile_iterator_A_.set_kgroup_index(0);
        this->warp_tile_iterator_B_.set_kgroup_index(0);

        this->warp_tile_iterator_A_.load(warp_frag_real_A[0]);
        this->warp_tile_iterator_A_.load_with_pointer_offset(
                warp_frag_imag_A[0], Base::SharedStorage::kImaginaryStrideA);

        this->warp_tile_iterator_B_.load(warp_frag_real_B[0]);
        this->warp_tile_iterator_B_.load_with_pointer_offset(
                warp_frag_imag_B[0], Base::SharedStorage::kImaginaryStrideB);

        ++this->warp_tile_iterator_A_;
        ++this->warp_tile_iterator_B_;

        if (gemm_k_iterations == 0) {
            iterator_A_real.clear_mask();
            iterator_A_imag.clear_mask();
            iterator_B_real.clear_mask();
            iterator_B_imag.clear_mask();
        }

        // Start issuing the first group of the next stage outside of the
        // mainloop
        copy_tiles_and_advance(iterator_A_real, iterator_A_imag,
                               iterator_B_real, iterator_B_imag);

        Operator warp_mma;

        int smem_write_stage_idx = Base::kStages - 1;
        int smem_read_stage_idx = 0;

        //
        // Mainloop
        //
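        // Note: gemm_k_iterations was already decremented once per prologue
        // stage, so the bound below keeps the loop running until it reaches
        // 1 - kStages, draining the pipeline's in-flight stages after the
        // final tile has been loaded.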

        CUTLASS_GEMM_LOOP
        for (; gemm_k_iterations > (-Base::kStages + 1);) {
            //
            // Loop over GEMM K dimension
            //

            // Computes a warp-level GEMM on data held in shared memory
            // Each "warp_mma_k" refers to a warp-level matrix
            // multiply-accumulate
            CUTLASS_PRAGMA_UNROLL
            for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations;
                 ++warp_mma_k) {
                // Load warp-level tiles from shared memory, wrapping around
                // to the first k-group when this is the last iteration of
                // the stage.

                this->warp_tile_iterator_A_.set_kgroup_index(
                        (warp_mma_k + 1) % Base::kWarpGemmIterations);
                this->warp_tile_iterator_B_.set_kgroup_index(
                        (warp_mma_k + 1) % Base::kWarpGemmIterations);

                this->warp_tile_iterator_A_.load(
                        warp_frag_real_A[(warp_mma_k + 1) % 2]);
                this->warp_tile_iterator_A_.load_with_pointer_offset(
                        warp_frag_imag_A[(warp_mma_k + 1) % 2],
                        Base::SharedStorage::kImaginaryStrideA);

                this->warp_tile_iterator_B_.load(
                        warp_frag_real_B[(warp_mma_k + 1) % 2]);
                this->warp_tile_iterator_B_.load_with_pointer_offset(
                        warp_frag_imag_B[(warp_mma_k + 1) % 2],
                        Base::SharedStorage::kImaginaryStrideB);

                ++this->warp_tile_iterator_A_;
                ++this->warp_tile_iterator_B_;

                // Issue global->shared copies for the next stage
                int group_start_iteration_A, group_start_iteration_B;

                if (warp_mma_k + 1 == Base::kWarpGemmIterations) {
                    group_start_iteration_A = 0;
                    group_start_iteration_B = 0;
                } else {
                    group_start_iteration_A =
                            (warp_mma_k + 1) * Detail::kAccessesPerGroupA;
                    group_start_iteration_B =
                            (warp_mma_k + 1) * Detail::kAccessesPerGroupB;
                }

                copy_tiles_and_advance(iterator_A_real, iterator_A_imag,
                                       iterator_B_real, iterator_B_imag,
                                       group_start_iteration_A,
                                       group_start_iteration_B);

                if (warp_mma_k + 2 == Base::kWarpGemmIterations) {
                    // Inserts a memory fence between stages of cp.async
                    // instructions
                    cutlass::arch::cp_async_fence();

                    // Blocks until all but kStages-2 cp.async stages have
                    // committed.
                    arch::cp_async_wait<Base::kStages - 2>();
                    __syncthreads();

                    // Move to the next stage
                    iterator_A_real.add_tile_offset({0, 1});
                    iterator_A_imag.add_tile_offset({0, 1});

                    iterator_B_real.add_tile_offset({1, 0});
                    iterator_B_imag.add_tile_offset({1, 0});

                    this->smem_iterator_A_.add_tile_offset({0, 1});
                    this->smem_iterator_B_.add_tile_offset({1, 0});

                    // Add negative offsets to return iterators to the 'start'
                    // of the circular buffer in shared memory
                    if (smem_write_stage_idx == (Base::kStages - 1)) {
                        this->smem_iterator_A_.add_tile_offset(
                                {0, -Base::kStages});
                        this->smem_iterator_B_.add_tile_offset(
                                {-Base::kStages, 0});
                        smem_write_stage_idx = 0;
                    } else {
                        ++smem_write_stage_idx;
                    }

                    if (smem_read_stage_idx == (Base::kStages - 1)) {
                        this->warp_tile_iterator_A_.add_tile_offset(
                                {0, -Base::kStages * Policy::kPartitionsK *
                                            Base::kWarpGemmIterations});

                        this->warp_tile_iterator_B_.add_tile_offset(
                                {-Base::kStages * Policy::kPartitionsK *
                                         Base::kWarpGemmIterations,
                                 0});
                        smem_read_stage_idx = 0;
                    } else {
                        ++smem_read_stage_idx;
                    }

                    --gemm_k_iterations;
                    if (gemm_k_iterations == 0) {
                        iterator_A_real.clear_mask();
                        iterator_A_imag.clear_mask();
                        iterator_B_real.clear_mask();
                        iterator_B_imag.clear_mask();
                    }
                }

                warp_mma_planar_complex(warp_mma, accum,
                                        warp_frag_real_A[warp_mma_k % 2],
                                        warp_frag_imag_A[warp_mma_k % 2],
                                        warp_frag_real_B[warp_mma_k % 2],
                                        warp_frag_imag_B[warp_mma_k % 2]);
            }
        }
    }
};
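
/// A minimal usage sketch (illustrative only: the iterator, policy, and
/// index values are assumptions, since these types are normally assembled by
/// CUTLASS's higher-level default templates rather than written by hand):
///
///   using Mma = cutlass::gemm::threadblock::MmaPlanarComplexMultistage<
///           Shape, IteratorA, SmemIteratorA, kCacheOpA, IteratorB,
///           SmemIteratorB, kCacheOpB, ElementC, LayoutC, Policy, kStages>;
///
///   __shared__ typename Mma::Base::SharedStorage shared_storage;
///   Mma mma(shared_storage, thread_idx, warp_idx, lane_idx);
///
///   typename Mma::FragmentC accum;
///   accum.real.clear();
///   accum.imag.clear();
///
///   mma(gemm_k_iterations, accum, iterator_A_real, iterator_A_imag,
///       iterator_B_real, iterator_B_imag, accum);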

/////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace threadblock
}  // namespace gemm
}  // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
